import argparse
from typing import Dict, List

import yaml


# Language codes covered by xwinograd.
# Each code is used verbatim as the HuggingFace dataset name (subset), and
# this script generates one yaml file per code.
# NOTE(review): "jp" is not the ISO 639-1 code for Japanese ("ja"); presumably
# it mirrors the subset name on HuggingFace -- confirm before changing it.
LANGUAGES = ["en", "fr", "jp", "pt", "ru", "zh"]


def doc_to_text(doc: Dict) -> int:
    """
    Map the document's answer label to the zero-based index of the correct choice.

    The label is read from ``doc["answer"]``: ``"1"`` maps to ``0`` and ``"2"``
    maps to ``1``. Any other value raises ``KeyError``.

    Note: We are using the "multiple input" mode of the multiple-choice
        output-type, which means we use different contexts with the same target
        for the different choices, rather than the same context and different targets.
    """
    label = doc["answer"]
    # Explicit lookup table (instead of arithmetic) so that unexpected labels
    # fail loudly rather than producing an out-of-range index.
    return {"1": 0, "2": 1}[label]


def doc_to_target(doc: Dict) -> str:
    """
    Return the part of the sentence that follows the first "_" placeholder.

    Surrounding whitespace is stripped from the result. The value is the same
    for both choices because we are using "multiple input" mode, so it does not
    depend on which option is correct.
    """
    sentence = doc["sentence"]
    # str.index raises ValueError when the sentence has no placeholder.
    blank_pos = sentence.index("_")
    continuation = sentence[blank_pos + 1 :]
    return continuation.strip()


def doc_to_choice(doc: Dict) -> List[str]:
    """
    Build the two candidate contexts used in "multiple input" mode.

    Each context is the sentence text before the first "_" placeholder with
    one of the options (``option1``, then ``option2``) appended in its place.
    """
    sentence = doc["sentence"]
    blank_pos = sentence.index("_")
    first, second = doc["option1"], doc["option2"]
    # Everything up to (but excluding) the placeholder is shared by both choices.
    prefix = sentence[:blank_pos]
    return [prefix + first, prefix + second]


def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Write one task config yaml per entry in ``LANGUAGES``.

    Every language is attempted even if an earlier file already exists; the
    names of the files that could not be created are collected and reported
    together once the loop has finished.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is false and at least one of the
        target files already exists.
    """
    # "x" is exclusive creation: open() raises FileExistsError instead of
    # truncating an existing file.
    open_mode = "w" if overwrite else "x"
    skipped = []

    for lang in LANGUAGES:
        file_name = f"xwinograd_{lang}.yaml"
        config = {
            "include": "xwinograd_common_yaml",
            "dataset_name": lang,
            "task": f"xwinograd_{lang}",
        }
        try:
            with open(f"{output_dir}/{file_name}", open_mode, encoding="utf-8") as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(config, f)
        except FileExistsError:
            skipped.append(file_name)

    if skipped:
        names = ", ".join(skipped)
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {names}"
        )


def main() -> None:
    """Read the command-line options and write the per-language yaml files."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    options = cli.parse_args()

    # argparse exposes "--output-dir" as the attribute "output_dir".
    gen_lang_yamls(output_dir=options.output_dir, overwrite=options.overwrite)


# Only generate the yaml files when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
